Apply use_reentrant removal to all TRL trainer configs #4321
The existing fix that removes use_reentrant=False from gradient_checkpointing_kwargs was gated behind RLConfig_name == "GRPOConfig", so only GRPOConfig was protected. SFTConfig, DPOConfig, KTOConfig, CPOConfig, ORPOConfig etc. were all still affected. Remove the GRPOConfig guard so the fix applies to all compiled trainer configs when TRL >= 0.27.0. This is defense-in-depth alongside the unsloth_zoo fix that forces use_reentrant=True in unsloth_checkpoint() itself.
Summary of Changes
This pull request expands a critical VRAM regression fix to cover all TRL trainer configurations. Previously, the fix was inadvertently limited to only GRPOConfig.
Code Review
This pull request correctly expands a fix for a VRAM regression to all TRL trainer configurations. Previously, the patch to remove use_reentrant=False from gradient_checkpointing_kwargs in TRL versions 0.27.0+ was limited to GRPOConfig. By removing this restriction, the fix now applies to all trainer configs, ensuring consistent behavior and preventing the regression in a wider range of scenarios. The change is safe, as it is guarded by the TRL version check and the patch itself defensively checks for the existence of the attribute and key before attempting deletion.
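The guarded removal described above can be sketched roughly as follows. This is a minimal illustration, not the actual code in the repository: the helper name strip_use_reentrant and the MIN_TRL_VERSION constant are hypothetical, and the TRL version is passed as a plain tuple so the sketch needs only the standard library, whereas the PR text shows the real check comparing against Version("0.27.0").

```python
from types import SimpleNamespace

# Assumption: version represented as a tuple for illustration only.
# The PR text shows the real guard as trl_version >= Version("0.27.0").
MIN_TRL_VERSION = (0, 27, 0)


def strip_use_reentrant(config, trl_version):
    """Hypothetical helper: drop 'use_reentrant' from
    config.gradient_checkpointing_kwargs for any trainer config.

    Returns True if the key was removed, False otherwise.
    """
    # Version guard: older TRL releases are left untouched.
    if trl_version < MIN_TRL_VERSION:
        return False

    # getattr with a default tolerates configs that lack the attribute.
    kwargs = getattr(config, "gradient_checkpointing_kwargs", None)

    # Only delete when the attribute is a dict and the key is present.
    if isinstance(kwargs, dict) and "use_reentrant" in kwargs:
        del kwargs["use_reentrant"]
        return True
    return False


# No check on the config class name: the same call works for a stand-in
# of any trainer config (SFTConfig, DPOConfig, GRPOConfig, ...).
sft_like = SimpleNamespace(gradient_checkpointing_kwargs={"use_reentrant": False})
strip_use_reentrant(sft_like, (0, 27, 0))
print(sft_like.gradient_checkpointing_kwargs)
```

Because the helper never inspects the config's class name, it behaves the same for every trainer config, which is the effect of dropping the GRPOConfig guard.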
Summary
The existing fix in rl.py that removes use_reentrant=False from gradient_checkpointing_kwargs (added for TRL 0.27.0+) was gated behind RLConfig_name == "GRPOConfig". This meant only GRPOConfig was protected from the VRAM regression, while SFTConfig, DPOConfig, KTOConfig, CPOConfig, ORPOConfig and all other trainer configs were still affected.
Change
Remove the GRPOConfig guard so the use_reentrant removal applies to all compiled trainer configs when TRL >= 0.27.0.
Why this is safe
- The check trl_version >= Version("0.27.0") is preserved, so TRL < 0.27.0 is unaffected
- The patch only removes use_reentrant from gradient_checkpointing_kwargs if it exists, using getattr and in checks
- This is defense-in-depth alongside the unsloth_zoo fix that forces use_reentrant=True in unsloth_checkpoint() itself
Compatibility
use_reentrant is not set)